{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 10.7 文本情感分类：使用循环神经网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:26:23.247619Z",
     "start_time": "2019-07-03T04:26:20.949830Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.1.0 cuda\n"
     ]
    }
   ],
   "source": [
    "import collections\n",
    "import os\n",
    "import random\n",
    "import tarfile\n",
    "import torch\n",
    "from torch import nn\n",
    "import torchtext.vocab as Vocab\n",
    "import torch.utils.data as Data\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\") \n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "DATA_ROOT = \"/data1/tangss/Datasets\"\n",
    "\n",
    "print(torch.__version__, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.7.1 文本情感分类数据\n",
    "### 10.7.1.1 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:26:23.255913Z",
     "start_time": "2019-07-03T04:26:23.250957Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "fname = os.path.join(DATA_ROOT, \"aclImdb_v1.tar.gz\")\n",
    "if not os.path.exists(os.path.join(DATA_ROOT, \"aclImdb\")):\n",
    "    print(\"从压缩包解压...\")\n",
    "    with tarfile.open(fname, 'r') as f:\n",
    "        f.extractall(DATA_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:26:39.257587Z",
     "start_time": "2019-07-03T04:26:23.258808Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 12500/12500 [00:00<00:00, 34211.42it/s]\n",
      "100%|██████████| 12500/12500 [00:00<00:00, 38506.48it/s]\n",
      "100%|██████████| 12500/12500 [00:00<00:00, 31316.61it/s]\n",
      "100%|██████████| 12500/12500 [00:00<00:00, 29664.72it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "def read_imdb(folder='train', data_root=\"/S1/CSCL/tangss/Datasets/aclImdb\"):  # 本函数已保存在d2lzh_pytorch包中方便以后使用\n",
    "    data = []\n",
    "    for label in ['pos', 'neg']:\n",
    "        folder_name = os.path.join(data_root, folder, label)\n",
    "        for file in tqdm(os.listdir(folder_name)):\n",
    "            with open(os.path.join(folder_name, file), 'rb') as f:\n",
    "                review = f.read().decode('utf-8').replace('\\n', '').lower()\n",
    "                data.append([review, 1 if label == 'pos' else 0])\n",
    "    random.shuffle(data)\n",
    "    return data\n",
    "\n",
    "data_root = os.path.join(DATA_ROOT, \"aclImdb\")\n",
    "train_data, test_data = read_imdb('train', data_root), read_imdb('test', data_root)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.7.1.2 预处理数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:26:39.262666Z",
     "start_time": "2019-07-03T04:26:39.259588Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_tokenized_imdb(data):  # 本函数已保存在d2lzh_pytorch包中方便以后使用\n",
    "    \"\"\"\n",
    "    data: list of [string, label]\n",
    "    \"\"\"\n",
    "    def tokenizer(text):\n",
    "        return [tok.lower() for tok in text.split(' ')]\n",
    "    return [tokenizer(review) for review, _ in data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:26:42.010298Z",
     "start_time": "2019-07-03T04:26:39.264464Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('# words in vocab:', 46152)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_vocab_imdb(data):  # 本函数已保存在d2lzh_pytorch包中方便以后使用\n",
    "    tokenized_data = get_tokenized_imdb(data)\n",
    "    counter = collections.Counter([tk for st in tokenized_data for tk in st])\n",
    "    return Vocab.Vocab(counter, min_freq=5)\n",
    "\n",
    "vocab = get_vocab_imdb(train_data)\n",
    "'# words in vocab:', len(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:26:42.016214Z",
     "start_time": "2019-07-03T04:26:42.012406Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def preprocess_imdb(data, vocab):  # 本函数已保存在d2lzh_torch包中方便以后使用\n",
    "    max_l = 500  # 将每条评论通过截断或者补0，使得长度变成500\n",
    "\n",
    "    def pad(x):\n",
    "        return x[:max_l] if len(x) > max_l else x + [0] * (max_l - len(x))\n",
    "\n",
    "    tokenized_data = get_tokenized_imdb(data)\n",
    "    features = torch.tensor([pad([vocab.stoi[word] for word in words]) for words in tokenized_data])\n",
    "    labels = torch.tensor([score for _, score in data])\n",
    "    return features, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.7.1.3 创建数据迭代器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:26:47.614720Z",
     "start_time": "2019-07-03T04:26:42.017922Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "train_set = Data.TensorDataset(*preprocess_imdb(train_data, vocab))\n",
    "test_set = Data.TensorDataset(*preprocess_imdb(test_data, vocab))\n",
    "train_iter = Data.DataLoader(train_set, batch_size, shuffle=True)\n",
    "test_iter = Data.DataLoader(test_set, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:26:47.624512Z",
     "start_time": "2019-07-03T04:26:47.616891Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X torch.Size([64, 500]) y torch.Size([64])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('#batches:', 391)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for X, y in train_iter:\n",
    "    print('X', X.shape, 'y', y.shape)\n",
    "    break\n",
    "'#batches:', len(train_iter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.7.2 使用循环神经网络的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:26:47.630109Z",
     "start_time": "2019-07-03T04:26:47.625789Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class BiRNN(nn.Module):\n",
    "    def __init__(self, vocab, embed_size, num_hiddens, num_layers):\n",
    "        super(BiRNN, self).__init__()\n",
    "        self.embedding = nn.Embedding(len(vocab), embed_size)\n",
    "        \n",
    "        # bidirectional设为True即得到双向循环神经网络\n",
    "        self.encoder = nn.LSTM(input_size=embed_size, \n",
    "                                hidden_size=num_hiddens, \n",
    "                                num_layers=num_layers,\n",
    "                                bidirectional=True)\n",
    "        self.decoder = nn.Linear(4*num_hiddens, 2) # 初始时间步和最终时间步的隐藏状态作为全连接层输入\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        # inputs的形状是(批量大小, 词数)，因为LSTM需要将序列长度(seq_len)作为第一维，所以将输入转置后\n",
    "        # 再提取词特征，输出形状为(词数, 批量大小, 词向量维度)\n",
    "        embeddings = self.embedding(inputs.permute(1, 0))\n",
    "        # rnn.LSTM只传入输入embeddings，因此只返回最后一层的隐藏层在各时间步的隐藏状态。\n",
    "        # outputs形状是(词数, 批量大小, 2 * 隐藏单元个数)\n",
    "        outputs, _ = self.encoder(embeddings) # output, (h, c)\n",
    "        # 连结初始时间步和最终时间步的隐藏状态作为全连接层输入。它的形状为\n",
    "        # (批量大小, 4 * 隐藏单元个数)。\n",
    "        encoding = torch.cat((outputs[0], outputs[-1]), -1)\n",
    "        outs = self.decoder(encoding)\n",
    "        return outs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:26:47.684133Z",
     "start_time": "2019-07-03T04:26:47.631441Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "embed_size, num_hiddens, num_layers = 100, 100, 2\n",
    "net = BiRNN(vocab, embed_size, num_hiddens, num_layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.7.2.1 加载预训练的词向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:26:47.895604Z",
     "start_time": "2019-07-03T04:26:47.685801Z"
    }
   },
   "outputs": [],
   "source": [
    "glove_vocab = Vocab.GloVe(name='6B', dim=100, cache=os.path.join(DATA_ROOT, \"glove\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:26:48.102388Z",
     "start_time": "2019-07-03T04:26:47.897582Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 21202 oov words.\n"
     ]
    }
   ],
   "source": [
    "def load_pretrained_embedding(words, pretrained_vocab):\n",
    "    \"\"\"从预训练好的vocab中提取出words对应的词向量\"\"\"\n",
    "    embed = torch.zeros(len(words), pretrained_vocab.vectors[0].shape[0]) # 初始化为0\n",
    "    oov_count = 0 # out of vocabulary\n",
    "    for i, word in enumerate(words):\n",
    "        try:\n",
    "            idx = pretrained_vocab.stoi[word]\n",
    "            embed[i, :] = pretrained_vocab.vectors[idx]\n",
    "        except KeyError:\n",
    "            oov_count += 1\n",
    "    if oov_count > 0:\n",
    "        print(\"There are %d oov words.\" % oov_count)\n",
    "    return embed\n",
    "\n",
    "net.embedding.weight.data.copy_(load_pretrained_embedding(vocab.itos, glove_vocab))\n",
    "net.embedding.weight.requires_grad = False # 直接加载预训练好的, 所以不需要更新它"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.7.2.2 训练并评价模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:47:57.808046Z",
     "start_time": "2019-07-03T04:26:48.104185Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.5415, train acc 0.719, test acc 0.819, time 48.7 sec\n",
      "epoch 2, loss 0.1897, train acc 0.837, test acc 0.852, time 53.0 sec\n",
      "epoch 3, loss 0.1105, train acc 0.857, test acc 0.844, time 51.6 sec\n",
      "epoch 4, loss 0.0719, train acc 0.881, test acc 0.865, time 52.1 sec\n",
      "epoch 5, loss 0.0519, train acc 0.894, test acc 0.852, time 51.2 sec\n"
     ]
    }
   ],
   "source": [
    "lr, num_epochs = 0.01, 5\n",
    "optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)\n",
    "loss = nn.CrossEntropyLoss()\n",
    "d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:47:57.813888Z",
     "start_time": "2019-07-03T04:47:57.810244Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 本函数已保存在d2lzh包中方便以后使用\n",
    "def predict_sentiment(net, vocab, sentence):\n",
    "    \"\"\"sentence是词语的列表\"\"\"\n",
    "    device = list(net.parameters())[0].device\n",
    "    sentence = torch.tensor([vocab.stoi[word] for word in sentence], device=device)\n",
    "    label = torch.argmax(net(sentence.view((1, -1))), dim=1)\n",
    "    return 'positive' if label.item() == 1 else 'negative'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:47:57.829262Z",
     "start_time": "2019-07-03T04:47:57.815487Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'positive'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict_sentiment(net, vocab, ['this', 'movie', 'is', 'so', 'great'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T04:47:57.838439Z",
     "start_time": "2019-07-03T04:47:57.830707Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'negative'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict_sentiment(net, vocab, ['this', 'movie', 'is', 'so', 'bad'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py36_pytorch]",
   "language": "python",
   "name": "conda-env-py36_pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
